# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List, Mapping

from promptflow.contracts.run_info import FlowRunInfo, RunInfo, Status


@dataclass
class LineResult:
    """Outcome of processing a single line of a flow run.

    Attributes:
        output: Output values produced by the line.
        aggregation_inputs: Node output values to be used as aggregation inputs;
            empty when the flow has no aggregation node.
        run_info: Flow-level run info of the line.
        node_run_infos: Run info of the nodes executed in the line.
    """

    output: Mapping[str, Any]
    aggregation_inputs: Mapping[str, Any]
    run_info: FlowRunInfo
    node_run_infos: Mapping[str, RunInfo]


@dataclass
class AggregationResult:
    """Outcome of running the aggregation nodes of a flow.

    Attributes:
        output: Output values of the aggregation nodes.
        metrics: Metrics generated by the aggregation.
        node_run_infos: Run info of the aggregation nodes.
    """

    output: Mapping[str, Any]
    metrics: Dict[str, Any]
    node_run_infos: Mapping[str, RunInfo]


@dataclass
class BulkResult:
    """The result of a bulk run.

    Bundles the per-line results with the aggregation result and derives summary
    statistics (node/line status counts and OpenAI usage totals) from the node run
    infos they contain.
    """

    outputs: List[Mapping[str, Any]]
    metrics: Mapping[str, Any]
    line_results: List[LineResult]  # Results of the processed lines.
    aggr_results: AggregationResult  # Result of running the aggregation nodes.

    def _get_line_run_infos(self):
        """Lazily yield the node run infos of every line result, line by line."""
        return (
            node_run_info for line_result in self.line_results for node_run_info in line_result.node_run_infos.values()
        )

    def _get_aggr_run_infos(self):
        """Lazily yield the node run infos of the aggregation nodes."""
        return (node_run_info for node_run_info in self.aggr_results.node_run_infos.values())

    def get_status_summary(self):
        """Summarize node and line statuses of the bulk run.

        Returns:
            A dict with the following keys:

            * ``__pf__.nodes.<node>.<status>``: for line nodes, the number of runs of
              ``<node>`` whose status is Completed, Bypassed or Failed (``<status>``
              is the lower-cased ``Status`` value). Other statuses are not counted.
            * ``__pf__.nodes.<node>.completed``: for aggregation nodes, 1 if the node
              completed and 0 otherwise.
            * ``__pf__.lines.completed``: the number of line indexes whose node runs
              all completed or were bypassed.
            * ``__pf__.lines.failed``: the remaining line indexes that have at least
              one node run. A line without any node run appears in neither count.
        """
        status_summary = {}
        # Line index -> whether every node run of that line completed or was bypassed.
        line_status = {}
        for run_info in self._get_line_run_infos():
            succeeded = run_info.status in (Status.Completed, Status.Bypassed)
            line_status[run_info.index] = line_status.get(run_info.index, True) and succeeded

            # Only consider Completed, Bypassed and Failed status, because the UX only supports these three statuses.
            if run_info.status in (Status.Completed, Status.Bypassed, Status.Failed):
                node_status_key = f"__pf__.nodes.{run_info.node}.{run_info.status.value.lower()}"
                status_summary[node_status_key] = status_summary.get(node_status_key, 0) + 1

        for run_info in self._get_aggr_run_infos():
            status_summary[f"__pf__.nodes.{run_info.node}.completed"] = 1 if run_info.status == Status.Completed else 0

        completed_lines = sum(line_status.values())
        status_summary["__pf__.lines.completed"] = completed_lines
        status_summary["__pf__.lines.failed"] = len(line_status) - completed_lines
        return status_summary

    def get_openai_metrics(self):
        """Sum the OpenAI usage metrics of all LLM calls made in the bulk run.

        Walks the API calls recorded on every line node run and aggregation node run,
        including nested child calls. For each call it adds up the numeric fields of
        every LLM call's ``output["usage"]`` dict.

        Returns:
            A dict mapping each usage field name to its total. It is empty when no
            LLM call reported usage.
        """
        node_run_infos = chain(self._get_line_run_infos(), self._get_aggr_run_infos())
        total_metrics = {}
        for run_info in node_run_infos:
            # ``api_calls`` is not guaranteed to be populated. Treat a missing (None)
            # value as "no calls" instead of failing the whole summary with TypeError.
            for call in run_info.api_calls or []:
                self._merge_metrics_dict(total_metrics, self._get_openai_metrics(call))
        return total_metrics

    def _get_openai_metrics(self, api_call: dict):
        """Return the summed usage metrics of ``api_call`` and all of its descendants.

        Args:
            api_call: A traced API call dict. Its own ``output["usage"]`` is counted
                only when ``_need_collect_metrics`` accepts it. Its ``children``
                entries, if present, are visited recursively.

        Returns:
            A new dict of summed usage values.
        """
        total_metrics = {}
        if self._need_collect_metrics(api_call):
            self._merge_metrics_dict(total_metrics, api_call["output"]["usage"])

        children = api_call.get("children")
        if children is not None:
            for child in children:
                self._merge_metrics_dict(total_metrics, self._get_openai_metrics(child))

        return total_metrics

    def _need_collect_metrics(self, api_call: dict):
        """Return True if ``api_call`` is an LLM call whose ``output["usage"]`` is a dict."""
        if api_call.get("type") != "LLM":
            return False
        output = api_call.get("output")
        if not isinstance(output, dict):
            return False
        usage = output.get("usage")
        if not isinstance(usage, dict):
            return False
        return True

    def _merge_metrics_dict(self, metrics: dict, metrics_to_merge: dict):
        """Add the numeric values of ``metrics_to_merge`` into ``metrics`` in place.

        Non-numeric values (e.g. nested detail dicts or ``None``) cannot be summed.
        They are skipped rather than raising ``TypeError`` and aborting the totals.
        """
        for k, v in metrics_to_merge.items():
            if not isinstance(v, (int, float)):
                continue
            metrics[k] = metrics.get(k, 0) + v
